import torch
from torch import nn
# Standalone demo of nn.BatchNorm1d in training mode (the nn.Module default):
# the layer normalizes a random (3, 1) batch with that batch's own statistics,
# and because momentum=0.5 its running_mean / running_var move halfway from
# their initial values toward those batch statistics after the forward pass.
bn_input = nn.BatchNorm1d(num_features=1, momentum=0.5)

aa = torch.rand(3, 1)  # batch of 3 samples, 1 feature each, drawn from U[0, 1)
print(aa)

normalized = bn_input(aa)  # forward pass also updates the running statistics
print(normalized)